{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sympy as sp\n",
    "from sympy import *  # noqa\n",
    "\n",
    "# Start the interactive SymPy session and enable pretty printing of the output.\n",
    "# NOTE(review): init_session presumably performs its own star import and\n",
    "# declares some default symbols -- confirm against the SymPy documentation.\n",
    "sp.init_session()\n",
    "sp.init_printing()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Selection of algorithm:\n",
    "- `divE_cleaning` (bool: `True`, `False`): \"$\\text{div}(\\boldsymbol{E})$ cleaning\" using the additional scalar field $F$;\n",
    "\n",
    "- `divB_cleaning` (bool: `True`, `False`): \"$\\text{div}(\\boldsymbol{B})$ cleaning\" using the additional scalar field $G$;\n",
    "\n",
    "- `J_in_time` (str: `'constant'`, `'linear'`, `'quadratic'`): $\\boldsymbol{J}$ constant, linear, or quadratic in time;\n",
    "\n",
    "- `rho_in_time` (str: `'constant'`, `'linear'`, `'quadratic'`): $\\rho$ constant, linear, or quadratic in time.\n",
    "\n",
    "In the notebook, the constant, linear, and quadratic coefficients of $\\boldsymbol{J}$ and $\\rho$ are denoted with the suffixes `_c0`, `_c1`, `_c2`, respectively. However, the corresponding coefficients in the displayed equations are denoted by the symbols $\\gamma$, $\\beta$, and $\\alpha$, respectively, with the quantity they refer to as a subscript. For example, if $\\rho$ is quadratic in time, it will be denoted as `rho = rho_c0 + rho_c1*(t-tn) + rho_c2*(t-tn)**2` in the notebook and $\\rho(t) = \\gamma_{\\rho} + \\beta_{\\rho} (t-t_n) + \\alpha_{\\rho} (t-t_n)^2$ in the displayed equations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Algorithm options (accepted values are listed in the cell above)\n",
    "divE_cleaning = True  # if True, the scalar field F is evolved as well\n",
    "divB_cleaning = True  # if True, the scalar field G is evolved as well\n",
    "J_in_time = \"constant\"  # time dependence of J\n",
    "rho_in_time = \"constant\"  # time dependence of rho (rho is defined only if divE_cleaning is True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Auxiliary functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_diag(W, D, P, invP):\n",
    "    \"\"\"\n",
    "    Check diagonalization of W as P*D*P**(-1).\n",
    "\n",
    "    Every entry of W - P*D*invP is expanded and simplified,\n",
    "    and is required to reduce to zero.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    W : matrix that has been diagonalized\n",
    "    D : diagonal matrix of eigenvalues\n",
    "    P : matrix of eigenvectors\n",
    "    invP : inverse of P\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    AssertionError\n",
    "        If any entry of W - P*D*invP does not simplify to zero.\n",
    "    \"\"\"\n",
    "    Wd = P * D * invP\n",
    "    n_rows, n_cols = Wd.shape\n",
    "    for i in range(n_rows):\n",
    "        for j in range(n_cols):\n",
    "            # Work on local values: nothing is written back into Wd,\n",
    "            # so that immutable matrices are accepted as well\n",
    "            entry = Wd[i, j].expand().simplify()\n",
    "            residual = (W[i, j] - entry).expand().simplify()\n",
    "            # Raise explicitly so that the check is not stripped\n",
    "            # when Python runs with optimizations enabled (-O)\n",
    "            if residual != 0:\n",
    "                raise AssertionError(\n",
    "                    f\"Diagonalization failed: W[{i},{j}] - Wd[{i},{j}] = {residual} is not zero\"\n",
    "                )\n",
    "\n",
    "\n",
    "def simple_mat(W):\n",
    "    \"\"\"\n",
    "    Expand and simplify every entry of the matrix W, in place.\n",
    "    \"\"\"\n",
    "    n_rows, n_cols = W.shape\n",
    "    for row in range(n_rows):\n",
    "        for col in range(n_cols):\n",
    "            entry = W[row, col]\n",
    "            W[row, col] = entry.expand().simplify()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Definition of symbols:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of evolved fields: the six components of E and B,\n",
    "# plus one scalar field for each enabled cleaning option\n",
    "dim = 6 + sum(1 for enabled in (divE_cleaning, divB_cleaning) if enabled)\n",
    "\n",
    "# Physical constants\n",
    "c, mu0 = sp.symbols(r\"c \\mu_0\", real=True, positive=True)\n",
    "\n",
    "# Time variables\n",
    "# (s is the auxiliary variable used in the integral over time)\n",
    "s, t, tn, dt = sp.symbols(r\"s t t_n \\Delta{t}\", real=True, positive=True)\n",
    "\n",
    "# Wave vector components: assuming that kx, ky and kz are positive is\n",
    "# general enough and makes some of the calculations easier for SymPy\n",
    "kx, ky, kz = sp.symbols(r\"k_x k_y k_z\", real=True, positive=True)\n",
    "\n",
    "# Cartesian components of the electric field\n",
    "Ex, Ey, Ez = sp.symbols(r\"E^x E^y E^z\")\n",
    "E = Matrix([Ex, Ey, Ez])\n",
    "\n",
    "# Cartesian components of the magnetic field\n",
    "Bx, By, Bz = sp.symbols(r\"B^x B^y B^z\")\n",
    "B = Matrix([Bx, By, Bz])\n",
    "\n",
    "# Scalar field F, needed only with div(E) cleaning\n",
    "if divE_cleaning:\n",
    "    F = sp.symbols(r\"F\")\n",
    "\n",
    "# Scalar field G, needed only with div(B) cleaning\n",
    "if divB_cleaning:\n",
    "    G = sp.symbols(r\"G\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### First-order ODEs for $\\boldsymbol{E}$, $\\boldsymbol{B}$, $F$ and $G$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time derivatives of the electric field components\n",
    "dEx_dt, dEy_dt, dEz_dt = (\n",
    "    I * c**2 * (ky * Bz - kz * By),\n",
    "    I * c**2 * (kz * Bx - kx * Bz),\n",
    "    I * c**2 * (kx * By - ky * Bx),\n",
    ")\n",
    "\n",
    "# Time derivatives of the magnetic field components\n",
    "dBx_dt, dBy_dt, dBz_dt = (\n",
    "    -I * (ky * Ez - kz * Ey),\n",
    "    -I * (kz * Ex - kx * Ez),\n",
    "    -I * (kx * Ey - ky * Ex),\n",
    ")\n",
    "\n",
    "# div(E) cleaning: add the term proportional to F to each component of dE/dt\n",
    "# and define the time derivative of the scalar field F\n",
    "if divE_cleaning:\n",
    "    dEx_dt, dEy_dt, dEz_dt = (\n",
    "        dE_comp + I * c**2 * F * k_comp\n",
    "        for dE_comp, k_comp in zip((dEx_dt, dEy_dt, dEz_dt), (kx, ky, kz))\n",
    "    )\n",
    "    dF_dt = I * (kx * Ex + ky * Ey + kz * Ez)\n",
    "\n",
    "# div(B) cleaning: add the term proportional to G to each component of dB/dt\n",
    "# and define the time derivative of the scalar field G\n",
    "if divB_cleaning:\n",
    "    dBx_dt, dBy_dt, dBz_dt = (\n",
    "        dB_comp + I * c**2 * G * k_comp\n",
    "        for dB_comp, k_comp in zip((dBx_dt, dBy_dt, dBz_dt), (kx, ky, kz))\n",
    "    )\n",
    "    dG_dt = I * (kx * Bx + ky * By + kz * Bz)\n",
    "\n",
    "# Column vectors of the time derivatives of the electric and magnetic fields\n",
    "dE_dt = Matrix([dEx_dt, dEy_dt, dEz_dt])\n",
    "dB_dt = Matrix([dBx_dt, dBy_dt, dBz_dt])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Linear system of ODEs for $\\boldsymbol{E}$, $\\boldsymbol{B}$, $F$ and $G$:\n",
    "$$\n",
    "\\frac{\\partial}{\\partial t}\n",
    "\\begin{bmatrix}\n",
    "\\boldsymbol{E} \\\\\n",
    "\\boldsymbol{B} \\\\\n",
    "F \\\\\n",
    "G\n",
    "\\end{bmatrix}\n",
    "= M\n",
    "\\begin{bmatrix}\n",
    "\\boldsymbol{E} \\\\\n",
    "\\boldsymbol{B} \\\\\n",
    "F \\\\\n",
    "G\n",
    "\\end{bmatrix}\n",
    "-\\mu_0 c^2 \n",
    "\\begin{bmatrix}\n",
    "\\boldsymbol{J} \\\\\n",
    "\\boldsymbol{0} \\\\\n",
    "\\rho \\\\\n",
    " 0\n",
    "\\end{bmatrix}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Column vector of all evolved fields\n",
    "fields_list = [Ex, Ey, Ez, Bx, By, Bz]\n",
    "fields_list += [F] if divE_cleaning else []\n",
    "fields_list += [G] if divB_cleaning else []\n",
    "EBFG = Matrix(fields_list)\n",
    "\n",
    "# Column vector of the first-order time derivatives of all evolved fields\n",
    "fields_list = [dEx_dt, dEy_dt, dEz_dt, dBx_dt, dBy_dt, dBz_dt]\n",
    "fields_list += [dF_dt] if divE_cleaning else []\n",
    "fields_list += [dG_dt] if divB_cleaning else []\n",
    "dEBFG_dt = Matrix(fields_list).expand()\n",
    "\n",
    "# Matrix M of the linear system of ODEs: entry (row, col) is the coefficient\n",
    "# of the field number col in the time derivative of the field number row\n",
    "M = Matrix(\n",
    "    [[dEBFG_dt[row].coeff(EBFG[col], 1) for col in range(dim)] for row in range(dim)]\n",
    ")\n",
    "print(r\"M = \")\n",
    "display(M)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Solution of linear system of ODEs for $\\boldsymbol{E}$, $\\boldsymbol{B}$, $F$ and $G$:\n",
    "\n",
    "The solution of the linear system of ODEs above is given by the superposition of the **general solution of the homogeneous system** (denoted with the subscript \"h\"),\n",
    "\n",
    "$$\n",
    "\\begin{bmatrix}\n",
    "\\boldsymbol{E}_h(t) \\\\\n",
    "\\boldsymbol{B}_h(t) \\\\\n",
    "F_h(t) \\\\\n",
    "G_h(t)\n",
    "\\end{bmatrix}\n",
    "= e^{M (t-t_n)}\n",
    "\\begin{bmatrix}\n",
    "\\boldsymbol{E}(t_n) \\\\\n",
    "\\boldsymbol{B}(t_n) \\\\\n",
    "F(t_n) \\\\\n",
    "G(t_n)\n",
    "\\end{bmatrix} \\,,\n",
    "$$\n",
    "\n",
    "and the **particular solution of the non-homogeneous system** (denoted with the subscript \"nh\"),\n",
    "\n",
    "$$\n",
    "\\begin{bmatrix}\n",
    "\\boldsymbol{E}_{nh}(t) \\\\\n",
    "\\boldsymbol{B}_{nh}(t) \\\\\n",
    "F_{nh}(t) \\\\\n",
    "G_{nh}(t)\n",
    "\\end{bmatrix}\n",
    "= -\\mu_0 c^2 e^{M t} \\left(\\int_{t_n}^t e^{-M s}\n",
    "\\begin{bmatrix}\n",
    "\\boldsymbol{J} \\\\\n",
    "\\boldsymbol{0} \\\\\n",
    "\\rho \\\\\n",
    "0\n",
    "\\end{bmatrix}\n",
    "ds\\right) \\,.\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Diagonalization of $M = P D P^{-1}$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# Compute matrices of eigenvectors and eigenvalues for diagonalization of M\n",
    "# (P: eigenvectors, D: diagonal matrix of eigenvalues, so that M = P*D*P**(-1))\n",
    "P, D = M.diagonalize()\n",
    "invP = P ** (-1)\n",
    "# NOTE(review): expD does not appear to be referenced by any other cell -- confirm\n",
    "expD = exp(D)\n",
    "# Raises an AssertionError if P*D*invP does not simplify back to M\n",
    "check_diag(M, D, P, invP)\n",
    "print(\"P = \")\n",
    "display(P)\n",
    "print(\"D = \")\n",
    "display(D)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Diagonalization of $W_1 = M (t-t_n) = P_1 D_1 P_1^{-1}$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# W1 = M*(t-tn) has the same eigenvectors as M,\n",
    "# with the eigenvalues rescaled by (t-tn)\n",
    "P1, invP1 = P, invP\n",
    "D1 = (t - tn) * D\n",
    "expD1 = exp(D1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Diagonalization of $W_2 = M t = P_2 D_2 P_2^{-1}$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# W2 = M*t has the same eigenvectors as M,\n",
    "# with the eigenvalues rescaled by t\n",
    "P2, invP2 = P, invP\n",
    "D2 = t * D\n",
    "expD2 = exp(D2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Diagonalization of $W_3 = -M s = P_3 D_3 P_3^{-1}$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# W3 = -M*s has the same eigenvectors as M,\n",
    "# with the eigenvalues rescaled by -s\n",
    "P3, invP3 = P, invP\n",
    "D3 = -s * D\n",
    "expD3 = exp(D3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### General solution (homogeneous system):\n",
    "\n",
    "$$\n",
    "\\begin{bmatrix}\n",
    "\\boldsymbol{E}_h(t) \\\\\n",
    "\\boldsymbol{B}_h(t) \\\\\n",
    "F_h(t) \\\\\n",
    "G_h(t)\n",
    "\\end{bmatrix}\n",
    "= e^{M (t-t_n)}\n",
    "\\begin{bmatrix}\n",
    "\\boldsymbol{E}(t_n) \\\\\n",
    "\\boldsymbol{B}(t_n) \\\\\n",
    "F(t_n) \\\\\n",
    "G(t_n)\n",
    "\\end{bmatrix}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# Matrix exponential exp(W1) = exp(M*(t-tn)), rebuilt from the diagonalization\n",
    "expW1 = P1 * expD1 * invP1\n",
    "\n",
    "# General solution as a function of t, then evaluated at t = tn+dt\n",
    "EBFG_h = expW1 * EBFG\n",
    "EBFG_h_new = EBFG_h.subs({t: tn + dt})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Definition of $\\boldsymbol{J}$ and $\\rho$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reject unknown options early: any unexpected value (e.g., a typo)\n",
    "# would otherwise be treated silently as \"constant\"\n",
    "if J_in_time not in (\"constant\", \"linear\", \"quadratic\"):\n",
    "    raise ValueError(\n",
    "        f\"J_in_time must be 'constant', 'linear', or 'quadratic', got {J_in_time!r}\"\n",
    "    )\n",
    "# rho is used only with div(E) cleaning, so rho_in_time is checked only then\n",
    "if divE_cleaning and rho_in_time not in (\"constant\", \"linear\", \"quadratic\"):\n",
    "    raise ValueError(\n",
    "        f\"rho_in_time must be 'constant', 'linear', or 'quadratic', got {rho_in_time!r}\"\n",
    "    )\n",
    "\n",
    "# Define J as a function of the integration variable s:\n",
    "# constant coefficients (always present)\n",
    "Jx_c0 = sp.symbols(r\"\\gamma_{J_x}\", real=True)\n",
    "Jy_c0 = sp.symbols(r\"\\gamma_{J_y}\", real=True)\n",
    "Jz_c0 = sp.symbols(r\"\\gamma_{J_z}\", real=True)\n",
    "Jx = Jx_c0\n",
    "Jy = Jy_c0\n",
    "Jz = Jz_c0\n",
    "# Linear coefficients (needed for J both linear and quadratic in time)\n",
    "if J_in_time in (\"linear\", \"quadratic\"):\n",
    "    Jx_c1 = sp.symbols(r\"\\beta_{J_x}\", real=True)\n",
    "    Jy_c1 = sp.symbols(r\"\\beta_{J_y}\", real=True)\n",
    "    Jz_c1 = sp.symbols(r\"\\beta_{J_z}\", real=True)\n",
    "    Jx += Jx_c1 * (s - tn)\n",
    "    Jy += Jy_c1 * (s - tn)\n",
    "    Jz += Jz_c1 * (s - tn)\n",
    "# Quadratic coefficients\n",
    "if J_in_time == \"quadratic\":\n",
    "    Jx_c2 = sp.symbols(r\"\\alpha_{J_x}\", real=True)\n",
    "    Jy_c2 = sp.symbols(r\"\\alpha_{J_y}\", real=True)\n",
    "    Jz_c2 = sp.symbols(r\"\\alpha_{J_z}\", real=True)\n",
    "    Jx += Jx_c2 * (s - tn) ** 2\n",
    "    Jy += Jy_c2 * (s - tn) ** 2\n",
    "    Jz += Jz_c2 * (s - tn) ** 2\n",
    "\n",
    "# Define rho as a function of the integration variable s\n",
    "# (only with div(E) cleaning)\n",
    "if divE_cleaning:\n",
    "    rho_c0 = sp.symbols(r\"\\gamma_{\\rho}\", real=True)\n",
    "    rho = rho_c0\n",
    "    if rho_in_time in (\"linear\", \"quadratic\"):\n",
    "        rho_c1 = sp.symbols(r\"\\beta_{\\rho}\", real=True)\n",
    "        rho += rho_c1 * (s - tn)\n",
    "    if rho_in_time == \"quadratic\":\n",
    "        rho_c2 = sp.symbols(r\"\\alpha_{\\rho}\", real=True)\n",
    "        rho += rho_c2 * (s - tn) ** 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Particular solution (non-homogeneous system):\n",
    "\n",
    "$$\n",
    "\\begin{bmatrix}\n",
    "\\boldsymbol{E}_{nh}(t) \\\\\n",
    "\\boldsymbol{B}_{nh}(t) \\\\\n",
    "F_{nh}(t) \\\\\n",
    "G_{nh}(t)\n",
    "\\end{bmatrix}\n",
    "= -\\mu_0 c^2 e^{M t} \\left(\\int_{t_n}^t e^{-M s}\n",
    "\\begin{bmatrix}\n",
    "\\boldsymbol{J} \\\\\n",
    "\\boldsymbol{0} \\\\\n",
    "\\rho \\\\\n",
    "0\n",
    "\\end{bmatrix}\n",
    "ds\\right)\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# Column vector of source terms, one entry per evolved field,\n",
    "# including the common factor -mu0*c**2\n",
    "fields_list = [Jx, Jy, Jz, 0, 0, 0]\n",
    "fields_list += [rho] if divE_cleaning else []\n",
    "fields_list += [0] if divB_cleaning else []\n",
    "S = Matrix([-mu0 * c**2 * source for source in fields_list])\n",
    "\n",
    "# Integrate expD3*invP3*S over s from tn to t, one component at a time\n",
    "# (assuming |k| is not zero); the remaining factor P3 of\n",
    "# exp(W3) = P3*expD3*invP3 is applied below\n",
    "tmp = expD3 * invP3 * S\n",
    "simple_mat(tmp)\n",
    "integral = Matrix([integrate(tmp[row], (s, tn, t)) for row in range(dim)])\n",
    "\n",
    "# Particular solution exp(W2) * P3 * integral, with exp(W2) = P2*expD2*invP2,\n",
    "# as a function of t and then evaluated at t = tn+dt\n",
    "tmp = invP2 * P3\n",
    "simple_mat(tmp)\n",
    "EBFG_nh = P2 * expD2 * tmp * integral\n",
    "EBFG_nh_new = EBFG_nh.subs({t: tn + dt})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Verification of the solution:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# Check, component by component, that the time derivative of the total\n",
    "# solution evaluated at t = tn equals the right-hand side of the ODE system\n",
    "for i in range(EBFG.shape[0]):\n",
    "    lhs = dEBFG_dt[i] + S[i]\n",
    "    lhs = lhs.subs(s, tn)  # sources were written as functions of s\n",
    "    rhs = (EBFG_h[i] + EBFG_nh[i]).diff(t)\n",
    "    rhs = rhs.subs(t, tn)  # results were written as functions of t\n",
    "    rhs = rhs.simplify()\n",
    "    # Named \"residual\" rather than \"diff\", so that the notebook-level name\n",
    "    # diff keeps referring to the function provided by \"from sympy import *\"\n",
    "    residual = (lhs - rhs).simplify()\n",
    "    assert residual == 0, \"Integration of linear system of ODEs failed\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Coefficients of PSATD equations (homogeneous system):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# Labels of the updated fields (L) and of the fields they depend on (R)\n",
    "L = [\"Ex\", \"Ey\", \"Ez\", \"Bx\", \"By\", \"Bz\"]\n",
    "if divE_cleaning:\n",
    "    L.append(\"F\")\n",
    "if divB_cleaning:\n",
    "    L.append(\"G\")\n",
    "R = list(L)\n",
    "\n",
    "# Compute individual coefficients in the update equations\n",
    "coeff_h = {}\n",
    "for i, lhs_name in enumerate(L):\n",
    "    for j, rhs_name in enumerate(R):\n",
    "        key = (lhs_name, rhs_name)\n",
    "        coefficient = EBFG_h_new[i].coeff(EBFG[j], 1)\n",
    "        coefficient = coefficient.expand().simplify()\n",
    "        coefficient = coefficient.rewrite(cos).trigsimp().simplify()\n",
    "        coeff_h[key] = coefficient\n",
    "        print(f\"Coefficient of {lhs_name} with respect to {rhs_name}:\")\n",
    "        display(coeff_h[key])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Coefficients of PSATD equations (non-homogeneous system):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# Labels of the updated fields\n",
    "L = [\"Ex\", \"Ey\", \"Ez\", \"Bx\", \"By\", \"Bz\"]\n",
    "if divE_cleaning:\n",
    "    L.append(\"F\")\n",
    "if divB_cleaning:\n",
    "    L.append(\"G\")\n",
    "\n",
    "# Labels (R) and symbols (cs) of the source coefficients,\n",
    "# built side by side so that they stay in the same order\n",
    "R = [\"Jx_c0\", \"Jy_c0\", \"Jz_c0\"]\n",
    "cs = [Jx_c0, Jy_c0, Jz_c0]\n",
    "if J_in_time in (\"linear\", \"quadratic\"):\n",
    "    R += [\"Jx_c1\", \"Jy_c1\", \"Jz_c1\"]\n",
    "    cs += [Jx_c1, Jy_c1, Jz_c1]\n",
    "if J_in_time == \"quadratic\":\n",
    "    R += [\"Jx_c2\", \"Jy_c2\", \"Jz_c2\"]\n",
    "    cs += [Jx_c2, Jy_c2, Jz_c2]\n",
    "if divE_cleaning:\n",
    "    R.append(\"rho_c0\")\n",
    "    cs.append(rho_c0)\n",
    "    if rho_in_time in (\"linear\", \"quadratic\"):\n",
    "        R.append(\"rho_c1\")\n",
    "        cs.append(rho_c1)\n",
    "    if rho_in_time == \"quadratic\":\n",
    "        R.append(\"rho_c2\")\n",
    "        cs.append(rho_c2)\n",
    "\n",
    "# Compute individual coefficients in the update equation\n",
    "coeff_nh = {}\n",
    "for i, lhs_name in enumerate(L):\n",
    "    for rhs_name, source_coeff in zip(R, cs):\n",
    "        key = (lhs_name, rhs_name)\n",
    "        coefficient = EBFG_nh_new[i].expand().coeff(source_coeff, 1)\n",
    "        coefficient = coefficient.expand().simplify()\n",
    "        coefficient = coefficient.rewrite(cos).trigsimp().simplify()\n",
    "        coeff_nh[key] = coefficient\n",
    "        print(f\"Coefficient of {lhs_name} with respect to {rhs_name}:\")\n",
    "        display(coeff_nh[key])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Coefficients of PSATD equations:\n",
    "\n",
    "Display the coefficients of the update equations one by one. For example, `coeff_h[('Ex', 'By')]` displays the coefficient of $E^x$ with respect to $B^y$ (resulting from the solution of the homogeneous system, hence the `_h` suffix), while `coeff_nh[('Ex', 'Jx_c0')]` displays the coefficient of $E^x$ with respect to $\\gamma_{J_x}$ (resulting from the solution of the non-homogeneous system, hence the `_nh` suffix). Note that $\\gamma_{J_x}$ is denoted as `Jx_c0` in the notebook, as described at the beginning of the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coefficient of Ex with respect to By (solution of the homogeneous system)\n",
    "coeff_h[(\"Ex\", \"By\")]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coefficient of Ex with respect to Jx_c0 (solution of the non-homogeneous system)\n",
    "coeff_nh[(\"Ex\", \"Jx_c0\")]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
